Skip to content

sage-attention: avoid torch custom_op for sm100 and add benchmark #452

Merged
drbh merged 1 commit into
mainfrom
update-sage-attn-ops
May 19, 2026
Merged

sage-attention: avoid torch custom_op for sm100 and add benchmark #452
drbh merged 1 commit into
mainfrom
update-sage-attn-ops

Conversation

@drbh
Copy link
Copy Markdown
Collaborator

@drbh drbh commented Mar 6, 2026

This PR

  • avoids the @torch.library.custom_op in sm100_compile.py similar to the sm90_compile.py file
  • adds a benchmark file
  • adds a readme example for simple validation script

@drbh drbh requested a review from danieldk as a code owner March 6, 2026 18:15
@drbh
Copy link
Copy Markdown
Collaborator Author

drbh commented Mar 6, 2026

note these changes may resolve the double registration issue seen on B200's

test with

# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
# ]
# ///
import torch
from kernels import get_kernel

torch.manual_seed(42)
sage_attention = get_kernel("drbh/sage-attn-test", version=2)

device = "cuda"
B, H, L, D = 1, 8, 256, 64
q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)

out = sage_attention.sageattn3_blackwell(q, k, v)
print(f"sageattn output shape: {out.shape}")

@drbh drbh merged commit 7535f60 into main May 19, 2026
7 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants